def _gap_styles(exp):
    """Return the (gap_label, marker, linestyle) triples to draw for *exp*.

    'Both_gap' is drawn once per gap condition; every other experiment is
    drawn once without filtering on ``gap_label`` (signalled by ``None``).
    """
    if exp == 'Both_gap':
        return [('short', 'o', '-'), ('long', 's', '--')]
    return [(None, 'o', '-')]


def _select_gap(data, gap_label):
    """Return the rows of *data* for *gap_label*, or all rows when it is None."""
    if gap_label is None:
        return data
    return data[data['gap_label'] == gap_label]


def _plot_observed_vs_predicted(ax, data, exp, obs_col, sem_col, pred_col, colors):
    """Draw observed points (with error bars) and predicted lines on *ax*.

    One series is drawn per WMSize level (and per gap condition for
    'Both_gap'), plotted against ``curDur``.
    """
    for gap_label, marker, linestyle in _gap_styles(exp):
        subset = _select_gap(data, gap_label)
        for wm_idx, wm in enumerate(sorted(subset['WMSize'].unique())):
            # Cycle the palette so more levels than colours cannot raise IndexError.
            color = colors[wm_idx % len(colors)]
            wm_data = subset[subset['WMSize'] == wm].sort_values('curDur')
            # Observed with error bars
            ax.errorbar(wm_data['curDur'], wm_data[obs_col],
                        yerr=wm_data[sem_col],
                        fmt=marker, color=color, markersize=5,
                        alpha=0.6, capsize=2, zorder=3)
            # Predicted
            ax.plot(wm_data['curDur'], wm_data[pred_col],
                    color=color, linestyle=linestyle, linewidth=1.5, zorder=2)


def _plot_bias_scatter(ax, raw, exp, colors):
    """Scatter predicted vs observed bias for every raw row of one experiment.

    Also draws the identity line and zero cross-hairs, and sets symmetric,
    equal-aspect limits around 0.
    """
    for gap_label, marker, _ in _gap_styles(exp):
        subset = _select_gap(raw, gap_label)
        for wm_idx, wm in enumerate(sorted(subset['WMSize'].unique())):
            wm_data = subset[subset['WMSize'] == wm]
            ax.scatter(wm_data['repErr'], wm_data['predErr'],
                       color=colors[wm_idx % len(colors)], marker=marker,
                       s=15, alpha=0.5, edgecolors='none')

    # Symmetric limits from the largest finite |value|; NaN/inf entries are
    # ignored because matplotlib rejects non-finite axis limits.
    values = np.concatenate([raw['repErr'].to_numpy(dtype=float),
                             raw['predErr'].to_numpy(dtype=float)])
    finite = values[np.isfinite(values)]
    max_abs = float(np.abs(finite).max()) if finite.size else 0.0
    # 10% margin; fall back to +/-1 when there is nothing to scale from.
    lim = max_abs * 1.1 if max_abs > 0 else 1.0

    ax.plot([-lim, lim], [-lim, lim], 'k--', linewidth=0.5, zorder=1)
    ax.axhline(y=0, color='gray', linestyle=':', linewidth=0.5, alpha=0.5)
    ax.axvline(x=0, color='gray', linestyle=':', linewidth=0.5, alpha=0.5)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_xlabel('Observed Bias (s)', fontsize=8)
    ax.tick_params(labelsize=7)
    ax.set_aspect('equal', adjustable='box')


def visualize_model_fit(model_name, experiments=None):
    """
    Visualize model fit: observed vs predicted reproduction biases.

    Reads ``<cpath>/data/<model_name>/<exp>_mdat.csv`` for each experiment
    (missing files are skipped) and draws a 3-row figure with one column per
    experiment: row 1 is reproduction bias vs duration, row 2 is CV vs
    duration, row 3 is a predicted-vs-observed bias scatter of the raw rows.
    The figure is saved to ``<OUTPUT_PATH>/<model_name>_model_fit.png`` and
    then shown. ``cpath`` and ``OUTPUT_PATH`` are module-level names.

    Each CSV must provide the columns ``WMSize``, ``curDur``, ``repErr``,
    ``predErr``, ``repCV`` and ``predCV``; 'Both_gap' additionally needs
    ``gap``.

    Parameters:
    -----------
    model_name : str
        Name of the model (e.g., 'FreeParameters', 'Experimentwise')
    experiments : list
        List of experiments to visualize (default: all 5 experiments)

    Raises:
    -------
    ValueError
        If no data file exists for any of the requested experiments.
    """
    if experiments is None:
        experiments = ['Encoding', 'Reproduction', 'Baseline', 'Both', 'Both_gap']

    # Collect all data into a single DataFrame
    data_dir = os.path.join(cpath, 'data', model_name)
    all_data = []
    for exp in experiments:
        mdat_file = os.path.join(data_dir, f'{exp}_mdat.csv')
        if not os.path.exists(mdat_file):
            continue
        mdat = pd.read_csv(mdat_file)
        mdat['Experiment'] = exp
        if exp == 'Both_gap':
            # Smallest gap -> 'short', largest -> 'long'. Any other gap value
            # maps to NaN and therefore matches neither plotted condition.
            mdat['gap_label'] = mdat['gap'].map(
                {mdat['gap'].min(): 'short', mdat['gap'].max(): 'long'})
        else:
            mdat['gap_label'] = 'none'
        all_data.append(mdat)

    if not all_data:
        # pd.concat([]) would raise an opaque "No objects to concatenate".
        raise ValueError(
            f"No '*_mdat.csv' files found in {data_dir!r} "
            f"for experiments {list(experiments)!r}")

    combined_data = pd.concat(all_data, ignore_index=True)

    # Aggregate: mean and SEM over all rows of each condition cell
    # (presumably one row per subject - confirm against the mdat files).
    agg_data = combined_data.groupby(
        ['Experiment', 'WMSize', 'curDur', 'gap_label']
    ).agg(
        repErr_mean=('repErr', 'mean'),
        repErr_sem=('repErr', 'sem'),
        predErr=('predErr', 'mean'),
        repCV_mean=('repCV', 'mean'),
        repCV_sem=('repCV', 'sem'),
        predCV=('predCV', 'mean'),
    ).reset_index()

    # Create figure - 3 rows layout matching template
    # Note: Row 1 & 2 share duration x-axis, Row 3 has bias x-axis (don't share)
    subplot_width = 2.5
    subplot_height = 2.2
    n_cols = len(experiments)
    # squeeze=False keeps `axes` 2-D even when there is a single experiment.
    fig, axes = plt.subplots(3, n_cols,
                             figsize=(subplot_width * n_cols, subplot_height * 3),
                             squeeze=False)

    # Color palette for memory loads
    colors = ['#d9d9d9', '#838383', '#3b3b3b']

    for idx, exp in enumerate(experiments):
        exp_data = agg_data[agg_data['Experiment'] == exp]
        ax_bias, ax_cv, ax_scatter = axes[0, idx], axes[1, idx], axes[2, idx]

        if len(exp_data) == 0:
            for ax in (ax_bias, ax_cv, ax_scatter):
                ax.text(0.5, 0.5, f'No data\nfor {exp}', ha='center',
                        va='center', transform=ax.transAxes)
            ax_bias.set_title(exp, fontsize=10)
            continue

        # === ROW 1: Bias plots ===
        _plot_observed_vs_predicted(ax_bias, exp_data, exp,
                                    'repErr_mean', 'repErr_sem', 'predErr', colors)
        ax_bias.axhline(y=0, color='black', linestyle='--', linewidth=0.5, zorder=1)
        ax_bias.set_title(exp, fontsize=10)
        ax_bias.set_xlim(0.3, 1.8)
        ax_bias.tick_params(labelsize=7)

        # === ROW 2: CV plots ===
        _plot_observed_vs_predicted(ax_cv, exp_data, exp,
                                    'repCV_mean', 'repCV_sem', 'predCV', colors)
        ax_cv.set_xlim(0.3, 1.8)
        ax_cv.tick_params(labelsize=7)

        # === ROW 3: Scatter plot (individual data points) ===
        exp_raw_data = combined_data[combined_data['Experiment'] == exp]
        _plot_bias_scatter(ax_scatter, exp_raw_data, exp, colors)

        # Y-axis labels and panel letters only on the left-most column
        if idx == 0:
            ax_bias.set_ylabel('Reproduction bias (s)', fontsize=8)
            ax_cv.set_ylabel('Coefficient of Variation', fontsize=8)
            ax_scatter.set_ylabel('Predicted Bias (s)', fontsize=8)
            for ax, letter in ((ax_bias, 'a'), (ax_cv, 'b'), (ax_scatter, 'c')):
                ax.text(-0.35, 0.5, letter, transform=ax.transAxes,
                        fontsize=12, fontweight='bold', va='center')

    # Legend - placed at bottom for better layout
    from matplotlib.lines import Line2D

    # Memory load legend (always present)
    legend_elements_load = [
        Line2D([0], [0], color=colors[0], lw=2, label='low'),
        Line2D([0], [0], color=colors[1], lw=2, label='medium'),
        Line2D([0], [0], color=colors[2], lw=2, label='high'),
    ]

    if 'Both_gap' in experiments:
        # Gap legend (only for Both_gap experiment)
        legend_elements_gap = [
            Line2D([0], [0], color='gray', lw=2, linestyle='-', marker='o',
                   markersize=5, label='short'),
            Line2D([0], [0], color='gray', lw=2, linestyle='--', marker='s',
                   markersize=5, label='long'),
        ]
        # Two legends side by side at the bottom. Figure.legend appends to
        # fig.legends, so both are kept without re-adding the first one.
        fig.legend(handles=legend_elements_load, loc='lower center',
                   bbox_to_anchor=(0.35, -0.02), frameon=False,
                   title='Memory Load', ncol=3, fontsize=8, title_fontsize=9)
        fig.legend(handles=legend_elements_gap, loc='lower center',
                   bbox_to_anchor=(0.68, -0.02), frameon=False,
                   title='Gap', ncol=2, fontsize=8, title_fontsize=9)
    else:
        # Single legend for memory load only
        fig.legend(handles=legend_elements_load, loc='lower center',
                   bbox_to_anchor=(0.5, -0.02), frameon=False,
                   title='Memory Load', ncol=3, fontsize=8, title_fontsize=9)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.06, hspace=0.3)  # Adjust spacing for 3 rows
    plt.savefig(os.path.join(OUTPUT_PATH, f'{model_name}_model_fit.png'),
                dpi=300, bbox_inches='tight')
    plt.show()
# Render, save and show the model-fit figure for the 'FreeParameters' model,
# using the function's default experiment list.
visualize_model_fit(model_name='FreeParameters')